{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os.path as ops\n",
    "import numpy as np\n",
    "import torch\n",
    "import cv2\n",
    "import time\n",
    "import os\n",
    "import matplotlib.pylab as plt\n",
    "import sys\n",
    "from tqdm import tqdm\n",
    "sys.path.append('..')\n",
    "from dataset.dataset_utils import TUSIMPLE\n",
    "from Lanenet.model2 import Lanenet\n",
    "from utils.evaluation import gray_to_rgb_emb, process_instance_embedding, video_to_clips"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load the Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "success\n"
     ]
    }
   ],
   "source": [
    "# Restore the trained LaneNet weights and prepare the network for inference.\n",
    "model_path = '../TUSIMPLE/Lanenet_output/lanenet_epoch_39_batch_8.model'\n",
    "device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n",
    "LaneNet_model = Lanenet(2, 4)\n",
    "# map_location lets a checkpoint that was saved on a GPU be restored on a CPU-only machine.\n",
    "LaneNet_model.load_state_dict(torch.load(model_path, map_location=device))\n",
    "LaneNet_model.to(device)\n",
    "# Put dropout / batch-norm layers (if the model has any) into inference behaviour.\n",
    "LaneNet_model.eval()\n",
    "print('success')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load the Test Dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The test uses the TUSIMPLE test dataset (``clips`` and ``test_label.json``), fills the prediction results into the entries loaded from ``test_tasks_0627.json``, and scores them with the TUSIMPLE evaluation script in ``utils/lane.py``."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Fill in ``lanes`` and ``run_time`` for each test sample; the results are later saved to ``pred.json``."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "pred_json_path = '../TUSIMPLE/test_set/test_tasks_0627.json'\n",
    "# The file holds one JSON object per line. A context manager closes the handle\n",
    "# deterministically, and blank lines are skipped instead of crashing json.loads.\n",
    "with open(pred_json_path) as pred_file:\n",
    "    json_pred = [json.loads(line) for line in pred_file if line.strip()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 2782/2782 [33:43<00:00,  1.37it/s]\n"
     ]
    }
   ],
   "source": [
    "# Run LaneNet on every test sample and fill `lanes` / `run_time` of each json_pred entry.\n",
    "all_time_forward = []     # forward-pass wall time, one entry per image (seconds)\n",
    "all_time_clustering = []  # clustering wall time, one entry per image (seconds)\n",
    "# Inputs are sent to whatever device the model weights live on (CPU or GPU).\n",
    "model_device = next(LaneNet_model.parameters()).device\n",
    "for sample in tqdm(json_pred):\n",
    "    h_samples = sample['h_samples']\n",
    "    img_path = ops.join('../TUSIMPLE/test_set', sample['raw_file'])\n",
    "    # read the image\n",
    "    gt_img_org = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)\n",
    "    if gt_img_org is None:\n",
    "        # cv2.imread reports failure by returning None rather than raising.\n",
    "        raise FileNotFoundError('Could not read image: {}'.format(img_path))\n",
    "    org_shape = gt_img_org.shape\n",
    "    gt_image = cv2.resize(gt_img_org, dsize=(512, 256), interpolation=cv2.INTER_LINEAR)\n",
    "    gt_image = gt_image / 127.5 - 1.0  # maps 0..255 onto -1..1\n",
    "    # HWC -> CHW; permute is the native torch equivalent of transposing the axes.\n",
    "    gt_image = torch.tensor(gt_image, dtype=torch.float).permute(2, 0, 1)\n",
    "    # Go through the network (no autograd bookkeeping is needed for evaluation)\n",
    "    with torch.no_grad():\n",
    "        time_start = time.time()\n",
    "        binary_final_logits, instance_embedding = LaneNet_model(\n",
    "            gt_image.unsqueeze(0).to(model_device))\n",
    "        if model_device.type == 'cuda':\n",
    "            # CUDA kernels run asynchronously; wait for them so the timing is real.\n",
    "            torch.cuda.synchronize()\n",
    "        time_end = time.time()\n",
    "    # Get the final embedding image\n",
    "    binary_img = torch.argmax(binary_final_logits, dim=1).squeeze().cpu().numpy()\n",
    "    binary_img[0:50, :] = 0  # discard the top 50 rows of the mask before clustering\n",
    "    clu_start = time.time()\n",
    "    rbg_emb, cluster_result = process_instance_embedding(instance_embedding.cpu(), binary_img,\n",
    "                                                         distance=1.5, lane_num=4)\n",
    "    clu_end = time.time()\n",
    "    # Bring the cluster label map back to the original image resolution.\n",
    "    cluster_result = cv2.resize(cluster_result, dsize=(org_shape[1], org_shape[0]),\n",
    "                                interpolation=cv2.INTER_NEAREST)\n",
    "    for line_idx in np.unique(cluster_result):\n",
    "        if line_idx == 0:\n",
    "            continue  # label 0 is not exported as a lane\n",
    "        select_mask = (cluster_result == line_idx)[h_samples]\n",
    "        row_result = []\n",
    "        for row_mask in select_mask:\n",
    "            col_indexes = np.nonzero(row_mask)[0]\n",
    "            if len(col_indexes) == 0:\n",
    "                row_result.append(-2)  # no pixel of this lane on the sampled row\n",
    "            else:\n",
    "                # midpoint between the left-most and right-most pixel of the lane\n",
    "                row_result.append(int(col_indexes.min() + (col_indexes.max()-col_indexes.min())/2))\n",
    "        sample['lanes'].append(row_result)\n",
    "    # Record timings once per image (not once per detected lane), so images without\n",
    "    # any lane are timed as well and the averages are per-image figures.\n",
    "    # NOTE(review): run_time is stored in seconds - confirm the unit utils/lane.py expects.\n",
    "    sample['run_time'] = time_end - time_start\n",
    "    all_time_forward.append(time_end - time_start)\n",
    "    all_time_clustering.append(clu_end - clu_start)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The Forward pass time for one image is: 50.020827770233154ms\n",
      "The Clustering time for one image is: 619.1658550898234ms\n",
      "The total time for one image is: 669.1866828600565ms\n",
      "The speed for foreard pass is: 19.991672360829842fps\n",
      "The speed for foreard pass is: 1.6150761411979484fps\n"
     ]
    }
   ],
   "source": [
    "# Average the timings over entries 500..1999 (the first 500 are skipped -\n",
    "# presumably as warm-up; confirm). np.mean divides by the number of entries\n",
    "# actually present instead of a hard-coded 1500.\n",
    "forward_avg = np.mean(all_time_forward[500:2000])\n",
    "cluster_avg = np.mean(all_time_clustering[500:2000])\n",
    "\n",
    "print('The Forward pass time for one image is: {}ms'.format(forward_avg*1000))\n",
    "print('The Clustering time for one image is: {}ms'.format(cluster_avg*1000))\n",
    "print('The total time for one image is: {}ms'.format((cluster_avg+forward_avg)*1000))\n",
    "\n",
    "print('The speed for forward pass is: {}fps'.format(1/forward_avg))\n",
    "print('The speed for clustering is: {}fps'.format(1/cluster_avg))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the predictions with one serialized sample per line.\n",
    "with open('../TUSIMPLE/pred.json', 'w') as out_file:\n",
    "    for entry in json_pred:\n",
    "        out_file.write(json.dumps(entry))\n",
    "        out_file.write('\\n')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Evaluation using TUSIMPLE script"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils.lane import LaneEval"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score the saved pred.json against test_label.json with the TUSIMPLE script.\n",
    "result = LaneEval.bench_one_submit(\n",
    "    '../TUSIMPLE/pred.json',\n",
    "    '../TUSIMPLE/test_set/test_label.json',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[{\"name\":\"Accuracy\",\"value\":0.9426404008763787,\"order\":\"desc\"},{\"name\":\"FP\",\"value\":0.14701653486700209,\"order\":\"asc\"},{\"name\":\"FN\",\"value\":0.0695542774982027,\"order\":\"asc\"}]\n"
     ]
    }
   ],
   "source": [
    "# Show the value returned by LaneEval.bench_one_submit for pred.json.\n",
    "print(result)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Evaluation for aug result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "success\n"
     ]
    }
   ],
   "source": [
    "# Restore the LaneNet weights from the *_AUG checkpoint and prepare for inference.\n",
    "model_path = '../TUSIMPLE/Lanenet_output/lanenet_epoch_39_batch_8_AUG.model'\n",
    "device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n",
    "LaneNet_model = Lanenet(2, 4)\n",
    "# Load onto the CPU first so the checkpoint can be read without a GPU, then move.\n",
    "LaneNet_model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))\n",
    "LaneNet_model.to(device)\n",
    "# Put dropout / batch-norm layers (if the model has any) into inference behaviour.\n",
    "LaneNet_model.eval()\n",
    "print('success')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "pred_json_path = '../TUSIMPLE/test_set/test_tasks_0627.json'\n",
    "# Reload the entries so this run starts from the file's contents. A context\n",
    "# manager closes the handle, and blank lines are skipped instead of crashing.\n",
    "with open(pred_json_path) as pred_file:\n",
    "    json_pred = [json.loads(line) for line in pred_file if line.strip()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 2782/2782 [1:54:29<00:00,  2.47s/it]  \n"
     ]
    }
   ],
   "source": [
    "# Run LaneNet on every test sample and fill `lanes` / `run_time` of each json_pred entry.\n",
    "all_time_forward = []     # forward-pass wall time, one entry per image (seconds)\n",
    "all_time_clustering = []  # clustering wall time, one entry per image (seconds)\n",
    "# Inputs must live on the same device as the model weights, otherwise the forward\n",
    "# pass fails whenever the model was moved to the GPU.\n",
    "model_device = next(LaneNet_model.parameters()).device\n",
    "for sample in tqdm(json_pred):\n",
    "    h_samples = sample['h_samples']\n",
    "    img_path = ops.join('../TUSIMPLE/test_set', sample['raw_file'])\n",
    "    # read the image\n",
    "    gt_img_org = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)\n",
    "    if gt_img_org is None:\n",
    "        # cv2.imread reports failure by returning None rather than raising.\n",
    "        raise FileNotFoundError('Could not read image: {}'.format(img_path))\n",
    "    org_shape = gt_img_org.shape\n",
    "    gt_image = cv2.resize(gt_img_org, dsize=(512, 256), interpolation=cv2.INTER_LINEAR)\n",
    "    gt_image = gt_image / 127.5 - 1.0  # maps 0..255 onto -1..1\n",
    "    # HWC -> CHW; permute is the native torch equivalent of transposing the axes.\n",
    "    gt_image = torch.tensor(gt_image, dtype=torch.float).permute(2, 0, 1)\n",
    "    # Go through the network (no autograd bookkeeping is needed for evaluation)\n",
    "    with torch.no_grad():\n",
    "        time_start = time.time()\n",
    "        binary_final_logits, instance_embedding = LaneNet_model(\n",
    "            gt_image.unsqueeze(0).to(model_device))\n",
    "        if model_device.type == 'cuda':\n",
    "            # CUDA kernels run asynchronously; wait for them so the timing is real.\n",
    "            torch.cuda.synchronize()\n",
    "        time_end = time.time()\n",
    "    # Get the final embedding image\n",
    "    binary_img = torch.argmax(binary_final_logits, dim=1).squeeze().cpu().numpy()\n",
    "    binary_img[0:50, :] = 0  # discard the top 50 rows of the mask before clustering\n",
    "    clu_start = time.time()\n",
    "    rbg_emb, cluster_result = process_instance_embedding(instance_embedding.cpu(), binary_img,\n",
    "                                                         distance=1.5, lane_num=4)\n",
    "    clu_end = time.time()\n",
    "    # Bring the cluster label map back to the original image resolution.\n",
    "    cluster_result = cv2.resize(cluster_result, dsize=(org_shape[1], org_shape[0]),\n",
    "                                interpolation=cv2.INTER_NEAREST)\n",
    "    for line_idx in np.unique(cluster_result):\n",
    "        if line_idx == 0:\n",
    "            continue  # label 0 is not exported as a lane\n",
    "        select_mask = (cluster_result == line_idx)[h_samples]\n",
    "        row_result = []\n",
    "        for row_mask in select_mask:\n",
    "            col_indexes = np.nonzero(row_mask)[0]\n",
    "            if len(col_indexes) == 0:\n",
    "                row_result.append(-2)  # no pixel of this lane on the sampled row\n",
    "            else:\n",
    "                # midpoint between the left-most and right-most pixel of the lane\n",
    "                row_result.append(int(col_indexes.min() + (col_indexes.max()-col_indexes.min())/2))\n",
    "        sample['lanes'].append(row_result)\n",
    "    # Record timings once per image (not once per detected lane), so images without\n",
    "    # any lane are timed as well and the averages are per-image figures.\n",
    "    # NOTE(review): run_time is stored in seconds - confirm the unit utils/lane.py expects.\n",
    "    sample['run_time'] = time_end - time_start\n",
    "    all_time_forward.append(time_end - time_start)\n",
    "    all_time_clustering.append(clu_end - clu_start)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the predictions with one serialized sample per line.\n",
    "with open('../TUSIMPLE/pred_aug.json', 'w') as out_file:\n",
    "    for entry in json_pred:\n",
    "        out_file.write(json.dumps(entry))\n",
    "        out_file.write('\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils.lane import LaneEval"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score the saved pred_aug.json against test_label.json with the TUSIMPLE script.\n",
    "result = LaneEval.bench_one_submit(\n",
    "    '../TUSIMPLE/pred_aug.json',\n",
    "    '../TUSIMPLE/test_set/test_label.json',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[{\"name\":\"Accuracy\",\"value\":0.9471710143439054,\"order\":\"desc\"},{\"name\":\"FP\",\"value\":0.15088066139468007,\"order\":\"asc\"},{\"name\":\"FN\",\"value\":0.06242511382698289,\"order\":\"asc\"}]\n"
     ]
    }
   ],
   "source": [
    "# Show the value returned by LaneEval.bench_one_submit for pred_aug.json.\n",
    "print(result)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
